from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np
import torchvision
import cv2

from Utils.inference import get_max_preds
import config as cfg

def save_batch_image_with_joints(batch_image, batch_joints, batch_joints_vis,
                                 file_name, nrow=8, padding=2):
    '''
    Tile a batch of images into one grid, draw each keypoint with its index
    label on top of the grid and write the result to file_name.

    batch_image: [batch_size, channel, height, width] images
    batch_joints: [batch_size, num_joints, 2 or 3] keypoints; only entries
        0 (x) and 1 (y) of each keypoint are read
    batch_joints_vis: [batch_size, num_joints, 1] keypoint visibility flags,
        taken from the ground-truth labels
    file_name: output path passed to cv2.imwrite
    nrow: number of images per grid row (forwarded to make_grid)
    padding: padding in pixels between grid cells (forwarded to make_grid)

    Keypoints whose visibility flag is 1 are drawn with colour [0, 255, 0],
    those with flag 0 with colour [0, 0, 255]; any other flag value is
    skipped. Labels are 1-based when cfg.data_type is 'mpii' and 0-based when
    it is 'vehicle'. For any other cfg.data_type nothing is drawn or written.

    The inputs are never modified: grid offsets are applied to local values
    only, so the same joints can safely be reused by the caller afterwards.
    '''
    # Vehicle keypoint index -> name, kept for reference:
    #   0: LF_light, 1: RF_light, 2: LB_light, 3: RB_light,
    #   4: LF_roof, 5: RF_roof, 6: RB_roof, 7: LB_roof,
    #   8: LF_chassis, 9: RF_chassis, 10: RB_chassis, 11: LB_chassis
    label_offset = {'mpii': 1, 'vehicle': 0}.get(cfg.data_type)
    if label_offset is None:
        # Unsupported dataset type: keep the historical "do nothing" behaviour.
        return

    # make_grid stitches the individual images into one (normalized) image.
    grid = torchvision.utils.make_grid(batch_image, nrow=nrow,
                                       padding=padding, normalize=True)
    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    # Copy so that OpenCV gets its own writable buffer to draw on.
    ndarr = ndarr.copy()

    nmaps = batch_image.size(0)
    xmaps = min(nrow, nmaps)
    cell_height = int(batch_image.size(2) + padding)
    cell_width = int(batch_image.size(3) + padding)

    font = cv2.FONT_HERSHEY_SIMPLEX
    visible_color = [0, 255, 0]
    hidden_color = [0, 0, 255]

    for k in range(nmaps):
        # Row-major position of image k inside the grid.
        row, col = divmod(k, xmaps)
        x_offset = col * cell_width + padding
        y_offset = row * cell_height + padding

        pairs = zip(batch_joints[k], batch_joints_vis[k])
        for index, (joint, joint_vis) in enumerate(pairs):
            if joint_vis[0] == 1:
                color = visible_color
            elif joint_vis[0] == 0:
                color = hidden_color
            else:
                continue

            # Shift into grid coordinates without writing back into `joint`.
            center = (int(x_offset + joint[0]), int(y_offset + joint[1]))
            cv2.circle(ndarr, center, 2, color, 2)
            cv2.putText(ndarr, str(index + label_offset), center,
                        font, 0.4, color, 1)

    cv2.imwrite(file_name, ndarr)

def save_batch_heatmaps(batch_image, batch_heatmaps, file_name,
                        normalize=True):
    '''
    Write one image that shows, for every sample of the batch, the input
    image followed by one colour-mapped heatmap overlay per joint.

    batch_image: [batch_size, channel, height, width]
    batch_heatmaps: [batch_size, num_joints, height, width]
    file_name: saved file name
    normalize: when true, a clone of batch_image is rescaled with its global
        min/max before being converted to 8-bit

    Layout of the saved image: one row per sample; column 0 holds the input
    image resized to the heatmap size, column j + 1 holds joint j's heatmap
    blended with that image. The arg-max location returned by get_max_preds
    is marked on both.
    '''
    if normalize:
        # Work on a clone so the caller's tensor is left untouched.
        batch_image = batch_image.clone()
        lowest = float(batch_image.min())
        highest = float(batch_image.max())

        batch_image.add_(-lowest).div_(highest - lowest + 1e-5)

    batch_size = batch_heatmaps.size(0)
    num_joints = batch_heatmaps.size(1)
    hm_height = batch_heatmaps.size(2)
    hm_width = batch_heatmaps.size(3)

    canvas_shape = (batch_size * hm_height, (num_joints + 1) * hm_width, 3)
    grid_image = np.zeros(canvas_shape, dtype=np.uint8)

    preds, _ = get_max_preds(batch_heatmaps.detach().cpu().numpy())

    marker_color = [0, 0, 255]
    target_size = (int(hm_width), int(hm_height))

    for sample_idx in range(batch_size):
        # Convert the image to an 8-bit, channel-last array.
        image_u8 = batch_image[sample_idx].mul(255).clamp(0, 255).byte()
        image = image_u8.permute(1, 2, 0).cpu().numpy()

        # Convert this sample's heatmaps to 8-bit.
        heatmaps_u8 = batch_heatmaps[sample_idx].mul(255).clamp(0, 255).byte()
        heatmaps = heatmaps_u8.cpu().numpy()

        resized_image = cv2.resize(image, target_size)

        rows = slice(hm_height * sample_idx, hm_height * (sample_idx + 1))

        for joint_idx in range(num_joints):
            peak = (int(preds[sample_idx][joint_idx][0]),
                    int(preds[sample_idx][joint_idx][1]))

            # Mark the peak on the plain image first; the blend below uses
            # the image with all markers drawn so far.
            cv2.circle(resized_image, peak, 1, marker_color, 1)

            colored_heatmap = cv2.applyColorMap(heatmaps[joint_idx, :, :],
                                                cv2.COLORMAP_JET)
            masked_image = colored_heatmap*0.7 + resized_image*0.3
            cv2.circle(masked_image, peak, 1, marker_color, 1)

            cols = slice(hm_width * (joint_idx + 1), hm_width * (joint_idx + 2))
            grid_image[rows, cols, :] = masked_image

        # Column 0: the input image with every joint's peak marked.
        grid_image[rows, 0:hm_width, :] = resized_image

    cv2.imwrite(file_name, grid_image)


def save_debug_images(input, meta, target, joints_pred, output,
                      prefix):
    '''
    Save the debug visualisations that are enabled in cfg.

    input: batch of samples, Tensor (batch_size, 3, image_size, image_size)
    meta: dict of per-sample metadata; the keys 'joints' and 'joints_vis'
        are read here
    target: ground-truth heatmaps,
        Tensor (batch_size, num_points, heatmap_size, heatmap_size)
    joints_pred: keypoints predicted by the model,
        ndarray (batch_size, num_points, 2)
    output: predicted heatmaps,
        Tensor (batch_size, num_points, heatmap_size, heatmap_size)
    prefix: common prefix of the written file names

    Example:
    input: Tensor (4, 3, 256, 256)
    meta: dict with 9 entries
    target: Tensor (4, 16, 64, 64)
    joints_pred: ndarray (4, 16, 2)
    output: Tensor (4, 16, 64, 64)
    prefix: val_10

    Each cfg flag controls one output file:
    save_batch_images_gt -> <prefix>_gt.jpg,
    save_batch_images_pred -> <prefix>_pred.jpg,
    save_heatmaps_gt -> <prefix>_hm_gt.jpg,
    save_heatmaps_pred -> <prefix>_hm_pred.jpg.
    '''
    # Keypoint overlays: ground truth first, then predictions.
    if cfg.save_batch_images_gt:
        gt_joints_path = '{}_gt.jpg'.format(prefix)
        save_batch_image_with_joints(input, meta['joints'],
                                     meta['joints_vis'], gt_joints_path)

    if cfg.save_batch_images_pred:
        pred_joints_path = '{}_pred.jpg'.format(prefix)
        save_batch_image_with_joints(input, joints_pred,
                                     meta['joints_vis'], pred_joints_path)

    # Heatmap overlays: ground truth first, then predictions.
    if cfg.save_heatmaps_gt:
        gt_heatmaps_path = '{}_hm_gt.jpg'.format(prefix)
        save_batch_heatmaps(input, target, gt_heatmaps_path)

    if cfg.save_heatmaps_pred:
        pred_heatmaps_path = '{}_hm_pred.jpg'.format(prefix)
        save_batch_heatmaps(input, output, pred_heatmaps_path)
